package godb

import (
	"database/sql"
	"errors"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"github.com/spf13/cast"
	"github.com/xwb1989/sqlparser"
	"regexp"
	"strings"
	"time"
)

// MYSQL is the MySQL operations object returned by MYSQLConnect.
//
// It embeds conn and mysqlQuery, so their fields and methods are promoted
// onto MYSQL, and it remembers the DSN the pool was opened with.
type MYSQL struct {
	conn
	mysqlQuery
	// linkString is the DSN that MYSQLConnect passed to sql.Open.
	linkString string
}

// MYSQLConnect opens a MySQL connection pool and wraps it in a *MYSQL.
//
// The DSN is assembled in go-sql-driver/mysql format:
//
//	username:password@tcp(host)/db[?params]
//
// host is the "address:port" pair. Each non-empty element of other is a
// query-string fragment of extra driver settings; all of them are joined
// with "&", so both of these forms are accepted:
//
//	MYSQLConnect(host, user, pass, "mydb", "loc=Local&multiStatements=true")
//	MYSQLConnect(host, user, pass, "mydb", "loc=Local", "multiStatements=true")
//
// and yield a DSN such as:
//
//	user:secret@tcp(127.0.0.1:3306)/mydb?loc=Local&multiStatements=true
//
// sql.Open may only validate its arguments without dialing the server, so a
// nil error here does not prove the database is reachable.
func MYSQLConnect(host, username, password, db string, other ...string) (*MYSQL, error) {
	linkstring := username + ":" + password + "@tcp(" + host + ")/" + db
	// Previously only other[0] was used and any further settings were silently
	// dropped; keep every non-empty fragment instead.
	params := make([]string, 0, len(other))
	for _, p := range other {
		if p != "" {
			params = append(params, p)
		}
	}
	if len(params) > 0 {
		linkstring += "?" + strings.Join(params, "&")
	}
	sqlDB, err := sql.Open("mysql", linkstring)
	if err != nil {
		return nil, err
	}
	// Close connections that have sat idle for half an hour.
	sqlDB.SetConnMaxIdleTime(30 * time.Minute)
	result := &MYSQL{}
	result.SetSQLDB(sqlDB)
	// The database name is recorded because table-name formatting uses it.
	result.SetDataBaseName(db)
	result.exec = sqlDB
	result.driverName = MYSQLDriver
	result.linkString = linkstring
	return result, nil
}

// Transaction hands t to the embedded newTransaction helper, together with
// the GetTxSQL method of a fresh mysqlTx that carries this handle's database
// name. Any option values are forwarded unchanged. The error comes straight
// from newTransaction.
func (m *MYSQL) Transaction(t TransactionFunc, option ...*sql.TxOptions) error {
	txHelper := new(mysqlTx)
	txHelper.SetDataBaseName(m.dbname)
	bindTx := txHelper.GetTxSQL
	return m.newTransaction(t, bindTx, option...)
}

// Close calls Close on the embedded conn and then clears exec. MYSQLConnect
// sets exec to the *sql.DB, so clearing it means this handle no longer holds
// that executor.
func (m *MYSQL) Close() {
	m.conn.Close()
	m.exec = nil
}

// mysqlQuery adds the MySQL-specific QueryRow and QueryWithPage methods to
// the embedded query type. It is shared by MYSQL and mysqlTx.
type mysqlQuery struct {
	query
}

// mysqlBindVarRe matches the ":v1", ":v2", … bind-variable names that
// sqlparser writes in place of positional "?" arguments when a statement is
// re-rendered. It is compiled once here instead of on every QueryRow call.
var mysqlBindVarRe = regexp.MustCompile(`:v\d+`)

// QueryRow executes a SELECT and returns at most one row.
//
// The statement is parsed with sqlparser, and anything that does not parse
// to a *sqlparser.Select is rejected. The LIMIT row count is forced to 1,
// keeping any existing OFFSET. The statement is then re-rendered, its ":vN"
// bind variables are turned back into "?", and it is run through QueryRows.
// Connection, parse and statement-kind failures are reported inside the
// returned *QueryResult rather than as a separate error value.
//
// NOTE(review): the ":vN" -> "?" rewrite is textual, so a string literal that
// itself contains ":v<digits>" would be rewritten as well.
func (m *mysqlQuery) QueryRow(sql string, args ...interface{}) *QueryResult {
	if m.checkConn {
		if err := m.connect(); err != nil {
			return ErrQueryResult(err, m.dbname, sql, args)
		}
	}
	stmt, err := sqlparser.Parse(sql)
	if err != nil {
		return ErrQueryResult(fmt.Errorf("sql语句解析错误:%w", err), m.dbname, sql, args)
	}
	switch stmt := stmt.(type) {
	case *sqlparser.Select:
		if stmt.Limit != nil {
			stmt.Limit.Rowcount = sqlparser.NewIntVal([]byte("1"))
		} else {
			stmt.SetLimit(&sqlparser.Limit{
				Rowcount: sqlparser.NewIntVal([]byte("1")),
			})
		}
	default:
		return ErrQueryResult(errors.New("只支持select语句"), m.dbname, sql, args)
	}
	buf := sqlparser.NewTrackedBuffer(nil)
	stmt.Format(buf)
	sql = mysqlBindVarRe.ReplaceAllString(buf.String(), "?")
	return m.QueryRows(sql, args...)
}

// QueryWithPage runs a SELECT with MySQL LIMIT-based pagination.
//
// When page is nil the statement runs unchanged through QueryRows. Otherwise:
//
//   - The statement is split by mysqlSelectSQLParser.
//   - If page.Total is not yet known (< 1), a count query runs first and its
//     "num" column is stored via page.SetTotal.
//   - If the total is still < 1, an empty result is returned without running
//     the data query.
//   - Otherwise page.Rows rows are fetched starting at (page.Page-1)*page.Rows.
//     Page numbers below 1 are treated as page 1.
//
// When the statement has GROUP BY or HAVING, the grouped statement is counted
// through a derived table. A plain count would return one row per group, and
// only the first group's count would be taken as the total.
func (m *mysqlQuery) QueryWithPage(sql string, page *PageObj, args ...interface{}) *QueryResult {
	if page == nil {
		return m.QueryRows(sql, args...)
	}
	sqlInfo, err := mysqlSelectSQLParser(sql)
	if err != nil {
		return ErrQueryResult(err, m.dbname, sql, args)
	}
	var sqlBuilder strings.Builder
	if page.Total < 1 {
		if sqlInfo.groupBy == "" && sqlInfo.having == "" {
			sqlBuilder.WriteString("SELECT count(0) num FROM ")
			sqlBuilder.WriteString(sqlInfo.table)
			sqlBuilder.WriteString(sqlInfo.where)
		} else {
			// HAVING may refer to select-list aliases, so keep the original
			// columns in that case. Otherwise select a constant, which avoids
			// duplicate-column-name errors inside the derived table.
			innerColumns := "1"
			if sqlInfo.having != "" {
				innerColumns = sqlInfo.selectColumns
			}
			sqlBuilder.WriteString("SELECT count(0) num FROM (SELECT ")
			sqlBuilder.WriteString(innerColumns)
			sqlBuilder.WriteString(" FROM ")
			sqlBuilder.WriteString(sqlInfo.table)
			sqlBuilder.WriteString(sqlInfo.where)
			sqlBuilder.WriteString(sqlInfo.groupBy)
			sqlBuilder.WriteString(sqlInfo.having)
			// MySQL requires every derived table to have an alias.
			sqlBuilder.WriteString(") godb_page_count")
		}
		// NOTE(review): the count result is not checked for an error. A failed
		// count presumably surfaces here as a total of 0 — confirm how
		// QueryResult.Get behaves on error.
		result := m.QueryRows(sqlBuilder.String(), args...)
		page.SetTotal(cast.ToInt64(cast.ToString(result.Get("num"))))
		sqlBuilder.Reset()
	}
	if page.Total < 1 {
		return NewRowsResult(nil, sql, args)
	}
	offset := 0
	if page.Page > 1 {
		offset = (page.Page - 1) * page.Rows
	}
	sqlBuilder.WriteString("SELECT ")
	sqlBuilder.WriteString(sqlInfo.selectColumns)
	sqlBuilder.WriteString(" FROM ")
	sqlBuilder.WriteString(sqlInfo.table)
	sqlBuilder.WriteString(sqlInfo.where)
	sqlBuilder.WriteString(sqlInfo.groupBy)
	sqlBuilder.WriteString(sqlInfo.having)
	sqlBuilder.WriteString(sqlInfo.orderBy)
	sqlBuilder.WriteString(" LIMIT ")
	sqlBuilder.WriteString(cast.ToString(offset))
	sqlBuilder.WriteString(",")
	sqlBuilder.WriteString(cast.ToString(page.Rows))
	sql = sqlBuilder.String()
	return m.QueryRows(sql, args...)
}

// mysqlTx is the transaction-scoped variant of mysqlQuery. GetTxSQL points
// its executor at a *sql.Tx, and the Tx is also kept in the tx field.
type mysqlTx struct {
	mysqlQuery
	// tx is the transaction bound by the most recent GetTxSQL call.
	tx *sql.Tx
}

// GetTxSQL binds m to tx and returns m itself as a TxSQL.
//
// The executor is pointed at tx, and the connection check that QueryRow
// performs when checkConn is set is turned off. The check is presumably
// unnecessary because a *sql.Tx stays on one connection.
func (m *mysqlTx) GetTxSQL(tx *sql.Tx) TxSQL {
	// Route statements through the transaction.
	m.exec = tx
	// Skip connection checks while inside the transaction.
	m.checkConn = false
	m.tx = tx
	return m
}

// GetTx reports the *sql.Tx that GetTxSQL last bound to m. It is nil until
// one has been bound.
func (m *mysqlTx) GetTx() *sql.Tx { return m.tx }
